#!/usr/bin/env python3
# %%
from __future__ import annotations

import re
from collections import OrderedDict

import numpy as np

from dpdata.periodic_table import Element


class QuipGapxyzSystems:
    """Iterate over the frames stored in a QUIP/GAP xyz file.

    The file is opened on construction and read lazily: every ``next()``
    call reads one frame (atom-count line, key=value header line and the
    per-atom lines) and returns it parsed as a dict of numpy arrays.

    Parameters
    ----------
    file_name : str or path-like
        path of the xyz file, handed to ``open`` unchanged
    """

    def __init__(self, file_name):
        self.file_object = open(file_name)
        self.block_generator = self.get_block_generator()

    def __iter__(self):
        return self

    def __next__(self):
        return self.handle_single_xyz_frame(next(self.block_generator))

    def __del__(self):
        # ``open`` may have raised inside ``__init__``; in that case the
        # attribute was never set and there is nothing to close.
        file_object = getattr(self, "file_object", None)
        if file_object is not None:
            file_object.close()

    def get_block_generator(self):
        """Yield the raw text lines of each frame found in the file.

        Lines that do not start with an integer are skipped. When a line
        starting with an integer ``atom_num`` is found, it is returned
        together with the ``atom_num + 1`` lines following it.

        Yields
        ------
        list[str]
            ``atom_num + 2`` lines: count line, header line, atom lines

        Raises
        ------
        RuntimeError
            if the file ends before the frame is complete
        """
        atom_count_pattern = re.compile(r"^\s*(\d+)\s*")
        while True:
            line = self.file_object.readline()
            if not line:
                break
            matched = atom_count_pattern.match(line)
            if matched is None:
                continue
            atom_num = int(matched.group(1))
            lines = [line]
            # one header line plus one line per atom
            lines.extend(self.file_object.readline() for _ in range(atom_num + 1))
            # readline() returns "" at end of file, so an empty last entry
            # means the frame was truncated
            if not lines[-1]:
                raise RuntimeError(
                    f"this xyz file may lack of lines, should be {atom_num + 2};lines:{lines}"
                )
            yield lines

    @staticmethod
    def handle_single_xyz_frame(lines):
        """Parse the text lines of one frame into a system data dict.

        Parameters
        ----------
        lines : list[str]
            count line, header line and one line per atom

        Returns
        -------
        dict
            ``atom_names``, ``atom_numbs``, ``atom_types``, ``cells``
            (1, 3, 3), ``coords``, ``energies``, ``forces``, ``orig``
            (zeros of length 3) and, when the header carries a non-empty
            ``virial`` field, ``virials`` (1, 3, 3)

        Raises
        ------
        RuntimeError
            if the number of lines does not equal the atom count plus two,
            if ``Properties`` names an unknown field or a field with an
            unexpected datatype, or if no ``species`` column is declared
        """
        atom_num = int(lines[0].strip("\n").strip())
        if len(lines) != atom_num + 2:
            raise RuntimeError(
                f"format error, atom_num=={atom_num}, {len(lines)}!=atom_num+2"
            )

        # The trailing blank lets the pattern's final ``\s+`` match the
        # last key=value pair as well.
        data_format_line = lines[1].strip("\n").strip() + " "
        field_value_pattern = re.compile(
            r"(?P<key>\S+)=(?P<quote>[\'\"]?)(?P<value>.*?)(?P=quote)\s+"
        )
        prop_pattern = re.compile(
            r"(?P<key>\w+?):(?P<datatype>[a-zA-Z]):(?P<value>\d+)"
        )

        field_dict = {
            found.group("key"): found.group("value")
            for found in field_value_pattern.finditer(data_format_line)
        }
        prop_list = [
            found.groupdict() for found in prop_pattern.finditer(field_dict["Properties"])
        ]

        # All per-atom tokens are kept as strings here; numeric conversion
        # happens when the result dict is assembled.
        data_array = np.array([line.strip().split() for line in lines[2:]])

        # field name -> (required datatype code, flatten to one dimension)
        known_props = {
            "species": ("S", True),
            "pos": ("R", False),
            "Z": ("I", True),
            "force": ("R", False),
        }
        columns = {}
        used_column = 0
        for prop in prop_list:
            key = prop["key"]
            datatype = prop["datatype"]
            if key not in known_props:
                raise RuntimeError("unknown field {}".format(key))
            expected_datatype, flatten = known_props[key]
            if datatype != expected_datatype:
                raise RuntimeError(
                    f"datatype for {key} must be '{expected_datatype}' instead of {datatype}"
                )
            field_length = int(prop["value"])
            block = data_array[:, used_column : used_column + field_length]
            columns[key] = block.flatten() if flatten else block
            used_column += field_length

        type_array = columns.get("species")
        if type_array is None:
            raise RuntimeError("type_array can't be None type, check .xyz file")

        # Type indices are assigned in order of first appearance.
        type_map = {}
        type_num_dict = OrderedDict()
        atom_type_list = []
        for species in type_array:
            if species not in type_map:
                type_map[species] = len(type_map)
                type_num_dict[species] = 0
            atom_type_list.append(type_map[species])
            type_num_dict[species] += 1
        type_num_array = np.array(list(type_num_dict.items()))

        info_dict = {}
        info_dict["atom_names"] = list(type_num_array[:, 0])
        info_dict["atom_numbs"] = list(type_num_array[:, 1].astype(int))
        info_dict["atom_types"] = np.array(atom_type_list).astype(int)
        info_dict["cells"] = np.array(
            [np.array(field_dict["Lattice"].split()).reshape(3, 3)]
        ).astype(np.float64)
        info_dict["coords"] = np.array([columns.get("pos")]).astype(np.float64)
        info_dict["energies"] = np.array([field_dict["energy"]]).astype(np.float64)
        info_dict["forces"] = np.array([columns.get("force")]).astype(np.float64)
        # a missing or empty virial field leaves the key out entirely
        if field_dict.get("virial"):
            info_dict["virials"] = np.array(
                [np.array(field_dict["virial"].split()).reshape(3, 3)]
            ).astype(np.float64)
        info_dict["orig"] = np.zeros(3)
        return info_dict


def format_single_frame(data, frame_idx):
    """Render one frame of a system as QUIP/GAP XYZ text lines.

    The result starts with the atom count, followed by a header line
    holding ``energy``, an optional ``virial``, ``Lattice`` and the fixed
    ``Properties`` declaration, and then one line per atom with species,
    position, atomic number and force.

    Parameters
    ----------
    data : dict
        system data; ``energies``, ``cells``, ``coords`` and ``forces`` are
        indexed by frame, ``atom_names`` and ``atom_types`` describe the
        atoms, ``virials`` is written only when the key is present
    frame_idx : int
        index of the frame to write

    Returns
    -------
    list[str]
        the lines of the frame, without trailing newlines
    """
    n_atoms = len(data["atom_types"])

    # --- header line: key=value fields joined by four blanks ---
    frame_energy = data["energies"][frame_idx]
    fields = [f"energy={frame_energy:.12e}"]

    if "virials" in data:
        virial_values = data["virials"][frame_idx].flatten()
        virial_text = "    ".join(f"{v:.12e}" for v in virial_values)
        fields.append(f'virial="{virial_text}"')

    lattice_values = data["cells"][frame_idx].flatten()
    lattice_text = "   ".join(f"{c:.12e}" for c in lattice_values)
    fields.append(f'Lattice="{lattice_text}"')

    fields.append("Properties=species:S:1:pos:R:3:Z:I:1:force:R:3")
    comment_line = "    ".join(fields)

    # --- per-atom lines, appended directly after the two leading lines ---
    frame_coords = data["coords"][frame_idx]
    frame_forces = data["forces"][frame_idx]
    names = np.array(data["atom_names"])
    type_indices = data["atom_types"]

    result = [str(n_atoms), comment_line]
    for atom in range(n_atoms):
        symbol = names[type_indices[atom]]
        x, y, z = frame_coords[atom]
        fx, fy, fz = frame_forces[atom]
        z_number = Element(symbol).Z

        position_text = "   ".join(f"{c:.11e}" for c in (x, y, z))
        force_text = f"{fx:.11e}  {fy:.11e}   {fz:.11e}"
        result.append(f"{symbol}    {position_text}   {z_number}    {force_text}")

    return result
